# -*- coding: utf-8 -*-
import argparse
import json
import time
import numpy as np
from sklearn.metrics import log_loss
import matplotlib.pyplot as plt
from keras.utils.np_utils import to_categorical
from keras.optimizers import SGD
import keras.backend as K
import pandas as pd
from utils.Models import choose_model
from DataUtils.Load_util import load_local_train_val_DEBUG, select_drivers, load_test, load_local_train_val, \
    load_local_train_val_multiscale, new_select_drivers, resample_by_sortedID
from DataUtils.Preprocessing import preprocessing, data_augmentation
from sklearn.utils import shuffle
from keras.preprocessing.image import ImageDataGenerator
import os


def set_configs(args=None):
    """Parse command-line options and build the experiment file tag.

    Parameters
    ----------
    args : list of str, optional
        Argument list to parse. Defaults to ``sys.argv[1:]`` (standard
        argparse behaviour); pass an explicit list to make the function
        usable from tests or other scripts.

    Returns
    -------
    (dict, str)
        The parsed options as a plain dict, and a descriptive tag string
        used to name weight folders / metadata / submission files.
    """
    def _size_pair(value):
        # Parse "--resize W,H" from the command line. The original used
        # type=tuple, which splits the string into single characters
        # ("224,224" -> ('2','2','4',...)). Tuples (the default value)
        # pass through untouched.
        if isinstance(value, tuple):
            return value
        return tuple(int(part) for part in value.split(','))

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='train', help='train, submit_multi, submit_select')
    parser.add_argument('-name', '--experiment_name', type=str, default='baseline')
    parser.add_argument('--to_gray', type=str, default='False', help='whether transform the img to gray form')
    parser.add_argument('--resize', type=_size_pair, default=(224, 224),
                        help='resize the img, e.g. 224,224')  # be attention to the cv2's size property is not same to numpy's
    parser.add_argument('-WP', '--weights_path', type=str, default='vgg16.h5', help='path to exist weights')
    parser.add_argument('-M', '--model_type', type=str, default='VGG_16',
                        help='model type: VGG_16 KT(keras_cifar10_exa) KB(keras baseline)')
    parser.add_argument('-O', '--optimizer', type=str, default='SGD', help='SGD adam')
    parser.add_argument('--loss', type=str, default='categorical_crossentropy', help='loss function')
    parser.add_argument('-NE', '--num_epochs', type=int, default=5, help='epochs')
    parser.add_argument('--pre_process', type=str, default=None)
    parser.add_argument('-D', '--debug', type=str, default='False')
    parser.add_argument('-CC', '--crop_center', type=str, default='True')
    parser.add_argument('-CF', '--classifier', type=str, default='softmax', help='softmax sigmoid')
    parser.add_argument('-TConv', '--trainable_conv', type=str, default='False')
    parser.add_argument('-TFC', '--trainable_fc', type=str, default='True')
    parser.add_argument('-BS', '--batch_size', type=int, default=16)
    parser.add_argument('-CR', '--crop_right', type=str, default='False')
    parser.add_argument('-Aug', '--augmentation', type=str, default='False')
    parser.add_argument('-ConTi', '--continue', type=str, default='False')
    parser.add_argument('-DMTr', '--Data_Mode_Tr', type=str, default='single_scale')
    config = vars(parser.parse_args(args))
    # Tag string encoding the main hyper-parameters; used as the folder /
    # file name for weights, metadata JSON and submission CSVs.
    file_descriptions = ('BatchNorm'
                         + '_Md-' + config['model_type']
                         + '_E-' + str(config['num_epochs'])
                         + '_BS-' + str(config['batch_size'])
                         + '_CR-' + config['crop_right']
                         + '_CC-' + config['crop_center']
                         + '_Aug-' + config['augmentation']
                         + '_ConTi-' + config['continue']
                         + '_L-' + config['loss']
                         + '_CF-' + config['classifier'])
    return config, file_descriptions


def train_ddd(config, file_descriptions):
    """Train a model on the distracted-driver data, checkpointing each epoch.

    The train/validation split is by driver id: the last unique driver is
    held out for validation so no driver appears in both sets. Training is
    run as ``config['num_epochs']`` single-epoch fits so the weights can be
    saved after every epoch; loss/accuracy curves, timings and the config
    are dumped to a JSON file next to the weight folder.

    Parameters
    ----------
    config : dict
        Parsed options from ``set_configs``.
    file_descriptions : str
        Tag used for the weight folder and the metadata JSON file name.
    """
    with K.tf.device('/gpu:0'):
        model = choose_model(config)
        # Resolve the optimizer / loss objects from their string options.
        opti = None
        loss_fn = None
        if config['optimizer'] == 'SGD':
            opti = SGD(lr=1e-5, decay=1e-6, momentum=0.9, nesterov=True)
        elif config['optimizer'] == 'adam':
            opti = 'adam'
        if config['loss'] == 'categorical_crossentropy':
            loss_fn = 'categorical_crossentropy'
        model.compile(optimizer=opti, loss=loss_fn, metrics=['accuracy'])
        if config['continue'] == 'True':
            # NOTE(review): resume path is hard-coded to one specific past
            # run — confirm it matches the experiment being continued.
            print('continue!')
            model.load_weights(
                '/media/dell/delldisk/dell/wxm/Data/KaggleDDD/experiments/224x224/baseline/'
                'Md-VGG_16_E-3_BS-16_CR-False_CC-True_Aug-False_ConTi-True_L-categorical_crossentropy_CF-softmax/e-2.h5')

        # ---- load data (mode chosen by debug / Data_Mode_Tr flags) ----
        image_narrays = None
        lable_narrays = None
        drivers_id = None
        unique_drivers = None
        if config['debug'] == 'False' and config['Data_Mode_Tr'] == 'single_scale':
            image_narrays, lable_narrays, drivers_id, unique_drivers = load_local_train_val(config=config)
        elif config['debug'] == 'True' and config['Data_Mode_Tr'] == 'single_scale':
            image_narrays, lable_narrays, drivers_id, unique_drivers = load_local_train_val_DEBUG(config=config)
        elif config['Data_Mode_Tr'] == 'multi_scale':
            image_narrays, lable_narrays, drivers_id, unique_drivers = load_local_train_val_multiscale(config=config)

        # Per-channel means (BGR order, VGG/ImageNet convention), reshaped
        # to (3, 1, 1) so subtraction broadcasts over channels-first images.
        mean_vec = np.array([103.939, 116.779, 123.68], dtype=np.float32)
        reshaped_mean_vec = mean_vec.reshape(3, 1, 1)

        # Hold out the last unique driver for validation; train on the rest.
        unique_list_train = unique_drivers[0:-1]
        unique_list_val = unique_drivers[-1]
        num_samples = len(lable_narrays)
        print('{0} train_val samples'.format(num_samples))
        print('{0} drivers'.format(len(unique_drivers)))
        label_matrix = to_categorical(lable_narrays, nb_classes=10)
        train_X, train_Y, train_drivers_id, index = new_select_drivers(image_narrays, label_matrix, drivers_id,
                                                                       unique_list_train)
        val_X, val_Y, val_drivers_id, index = new_select_drivers(image_narrays, label_matrix, drivers_id,
                                                                 unique_list_val)
        train_X, train_Y, train_drivers_id = shuffle(train_X, train_Y, train_drivers_id)
        train_X, train_Y, train_drivers_id = resample_by_sortedID(train_X, train_Y, train_drivers_id)
        val_X, val_Y, val_drivers_id = resample_by_sortedID(val_X, val_Y, val_drivers_id)
        # Mean-subtract both splits (copies, so the raw arrays are untouched).
        train_X = train_X[:] - reshaped_mean_vec
        val_X = val_X[:] - reshaped_mean_vec
        print('{0} train samples'.format(len(train_Y)))
        print('{0} val samples'.format(len(val_Y)))

        # Checkpoint / metadata location; loop-invariant, build it once.
        experiment_path = '/media/dell/delldisk/dell/wxm/Data/KaggleDDD' \
                          + '/experiments/' + str(config['resize'][0]) + 'x' + str(config['resize'][1]) \
                          + '/' + config['experiment_name'] + '/'
        checkpoint_folder = experiment_path + file_descriptions
        if not os.path.exists(checkpoint_folder):
            # makedirs (not mkdir) so missing parent directories are created.
            os.makedirs(checkpoint_folder)

        # ---- train: one-epoch rounds so we can checkpoint each epoch ----
        # NOTE: renamed from `loss` in the original, which shadowed the
        # loss-function variable used for model.compile above.
        loss_hist = []
        val_loss = []
        acc = []
        val_acc = []
        time_counter = []
        train_log_loss = []
        val_log_loss = []
        for e in range(config['num_epochs']):
            epoch_begin_time = time.time()
            # NOTE(review): preprocessing is re-applied to train_X every
            # epoch; if it is not idempotent its effect compounds — confirm
            # this is intended.
            train_X = preprocessing(train_X, config)

            if config['augmentation'] == 'False':
                hist = model.fit(x=train_X, y=train_Y, batch_size=config['batch_size'], nb_epoch=1,
                                 validation_data=(val_X, val_Y), shuffle=True)
            elif config['augmentation'] == 'True':
                # Realtime data augmentation: small rotations/shifts only;
                # flips are off (a flipped driver image changes its class).
                datagen = ImageDataGenerator(
                    featurewise_center=False,  # set input mean to 0 over the dataset
                    samplewise_center=False,  # set each sample mean to 0
                    featurewise_std_normalization=False,  # divide inputs by std of the dataset
                    samplewise_std_normalization=False,  # divide each input by its std
                    zca_whitening=False,  # apply ZCA whitening
                    rotation_range=10,  # randomly rotate images in the range (degrees, 0 to 180)
                    width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
                    height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
                    horizontal_flip=False,  # randomly flip images
                    vertical_flip=False)  # randomly flip images

                # compute quantities required for featurewise normalization
                # (std, mean, and principal components if ZCA whitening is applied)
                datagen.fit(train_X)

                hist = model.fit_generator(datagen.flow(train_X, train_Y, batch_size=config['batch_size']),
                                           samples_per_epoch=train_X.shape[0],
                                           nb_epoch=1, validation_data=(val_X, val_Y))
            # Checkpoint this epoch's weights.
            model.save_weights(checkpoint_folder + '/e-' + str(e) + '.h5', overwrite=True)
            epoch_end_time = time.time()
            time_counter.append(epoch_end_time - epoch_begin_time)
            loss_hist += hist.history['loss']
            val_loss += hist.history['val_loss']
            acc += hist.history['acc']
            val_acc += hist.history['val_acc']
            # fig = plt.figure()
            """
            if e == 0 or (e+1) % 5 == 0:
                train_Y_pred = model.predict_proba(train_X)
                train_ll = log_loss(train_Y, train_Y_pred)
                train_log_loss += [train_ll]

                val_Y_pred = model.predict_proba(val_X)
                val_ll = log_loss(val_Y, val_Y_pred)
                val_log_loss += [val_ll]
                print 'train_log_loss: ' + str(train_ll)
                print 'val_log_loss: ' + str(val_ll)
            """
            print('epoch {0} done'.format(e + 1))
            print('-' * 50)

        # ---- persist run metadata (curves + config) as JSON ----
        training = {
            'loss': loss_hist,
            'val_loss': val_loss,
            'acc': acc,
            'val_acc': val_acc,
            'log_loss': train_log_loss,
            'val_log_loss': val_log_loss,
            'optimizer': model.optimizer.get_config(),
            'time': time_counter
        }
        metadata = {
            'training': training,
            'config': config
        }
        # Text mode ('w', not 'wb'): json.dumps returns str. The context
        # manager guarantees the file handle is closed even on error.
        with open(experiment_path + file_descriptions + '.json', 'w') as f:
            f.write(json.dumps(metadata, default=lambda o: o.__dict__, indent=4))


def submit_select(config, file_descriptions):
    """Write submission CSVs for a hand-picked list of epoch checkpoints.

    Loads the test images once, then for each epoch index in ``epochs``
    loads that checkpoint's weights and writes one Kaggle-format CSV
    (columns c0..c9 with class probabilities, plus the image name).

    Parameters
    ----------
    config : dict
        Parsed options from ``set_configs``.
    file_descriptions : str
        Tag naming the weight folder and the output CSVs.
    """
    epochs = [0, 1, 2, 3, 4]  # checkpoint epochs to score; edit by hand
    img_narrays, img_names = load_test(config=config)
    # Per-channel means (BGR, VGG convention), broadcast over
    # channels-first images.
    mean_vec = np.array([103.939, 116.779, 123.68], dtype=np.float32)
    reshaped_mean_vec = mean_vec.reshape(3, 1, 1)
    img_narrays -= reshaped_mean_vec
    with K.tf.device('/gpu:0'):
        model = choose_model(config)
        opti = None
        loss = None
        if config['optimizer'] == 'SGD':
            # Optimizer hyper-parameters are irrelevant at inference time;
            # compiling is only needed so predict_proba can run.
            opti = SGD(lr=0.1, momentum=0.9, nesterov=True)
        elif config['optimizer'] == 'adam':
            opti = 'adam'
        if config['loss'] == 'categorical_crossentropy':
            loss = 'categorical_crossentropy'

        model.compile(optimizer=opti, loss=loss)
        experiment_path = '/media/dell/delldisk/dell/wxm/Data/KaggleDDD' \
                          + '/experiments/' + str(config['resize'][0]) + 'x' + str(config['resize'][1]) \
                          + '/' + config['experiment_name'] + '/'
        h5_folder = experiment_path + file_descriptions
        # Output folder is loop-invariant: build and create it once.
        # BUG FIX: the original concatenated resize/experiment_name/tag
        # without '/' separators, producing one fused directory name;
        # makedirs (not mkdir) also creates any missing parents.
        submit_folder = '/media/dell/delldisk/dell/wxm/Data/KaggleDDD' \
                        + '/experiments/submissions/' + str(config['resize'][0]) + 'x' + str(config['resize'][1]) \
                        + '/' + config['experiment_name'] + '/' + file_descriptions
        if not os.path.exists(submit_folder):
            os.makedirs(submit_folder)
        for i in epochs:
            model_path = h5_folder + '/' + 'e-' + str(i) + '.h5'
            model.load_weights(model_path)
            pred = model.predict_proba(img_narrays, batch_size=config['batch_size'])

            result = pd.DataFrame(pred, columns=['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9'])
            result.loc[:, 'img'] = pd.Series(img_names, index=result.index)
            submit_file = submit_folder + '/' + 'e-' + str(i) + '_' + file_descriptions + '.csv'
            result.to_csv(submit_file, index=False)


def submit_multi(config, file_descriptions):
    """Write one submission CSV per training epoch checkpoint.

    Identical to ``submit_select`` except that it scores every epoch in
    ``range(config['num_epochs'])`` instead of a hand-picked list.

    Parameters
    ----------
    config : dict
        Parsed options from ``set_configs``.
    file_descriptions : str
        Tag naming the weight folder and the output CSVs.
    """
    img_narrays, img_names = load_test(config=config)
    # Per-channel means (BGR, VGG convention), broadcast over
    # channels-first images.
    mean_vec = np.array([103.939, 116.779, 123.68], dtype=np.float32)
    reshaped_mean_vec = mean_vec.reshape(3, 1, 1)
    img_narrays -= reshaped_mean_vec
    with K.tf.device('/gpu:0'):
        model = choose_model(config)
        opti = None
        loss = None
        if config['optimizer'] == 'SGD':
            # Optimizer hyper-parameters are irrelevant at inference time;
            # compiling is only needed so predict_proba can run.
            opti = SGD(lr=0.1, momentum=0.9, nesterov=True)
        elif config['optimizer'] == 'adam':
            opti = 'adam'
        if config['loss'] == 'categorical_crossentropy':
            loss = 'categorical_crossentropy'

        model.compile(optimizer=opti, loss=loss)
        experiment_path = '/media/dell/delldisk/dell/wxm/Data/KaggleDDD' \
                          + '/experiments/' + str(config['resize'][0]) + 'x' + str(config['resize'][1]) \
                          + '/' + config['experiment_name'] + '/'
        h5_folder = experiment_path + file_descriptions
        # Output folder is loop-invariant: build and create it once.
        # BUG FIX: the original concatenated resize/experiment_name/tag
        # without '/' separators, producing one fused directory name;
        # makedirs (not mkdir) also creates any missing parents.
        submit_folder = '/media/dell/delldisk/dell/wxm/Data/KaggleDDD' \
                        + '/experiments/submissions/' + str(config['resize'][0]) + 'x' + str(config['resize'][1]) \
                        + '/' + config['experiment_name'] + '/' + file_descriptions
        if not os.path.exists(submit_folder):
            os.makedirs(submit_folder)
        for i in range(config['num_epochs']):
            model_path = h5_folder + '/' + 'e-' + str(i) + '.h5'
            model.load_weights(model_path)
            pred = model.predict_proba(img_narrays, batch_size=config['batch_size'])

            result = pd.DataFrame(pred, columns=['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9'])
            result.loc[:, 'img'] = pd.Series(img_names, index=result.index)
            submit_file = submit_folder + '/' + 'e-' + str(i) + '_' + file_descriptions + '.csv'
            result.to_csv(submit_file, index=False)


if __name__ == '__main__':
    # Entry point: parse options once, then dispatch on the chosen mode.
    config, file_descriptions = set_configs()
    mode_handlers = {
        'train': train_ddd,
        'submit_multi': submit_multi,
        'submit_select': submit_select,
    }
    handler = mode_handlers.get(config['mode'])
    if handler is not None:
        handler(config, file_descriptions)
    else:
        print('wrong mode: expect train or submit')
